import argparse
import torch
import numpy as np
import torch.optim as optim
from utils import KpiReader
from models import StackedVAGT
from logger import Logger
from metrics import All_Metrics
from load_ett import _get_data
from logger import get_logger

class Tester(object):
    """Evaluate a (StackedVAGT-style) model on a test DataLoader and log forecasting metrics.

    The Adam optimizer is created only so that an optimizer state stored in a
    checkpoint can be restored alongside the model weights; no training step is
    taken by this class.
    """

    def __init__(self, model,  testloader, log_path='log_tester', log_file='loss', device=torch.device('cpu'),
                 learning_rate=0.0002, nsamples=None, sample_path=None, checkpoints=None):
        """
        Args:
            model: torch module under test; it is moved to ``device`` here.
            testloader: iterable yielding ``(data, target)`` tensor pairs.
            log_path: directory handed to ``get_logger``.
            log_file: stored on the instance; not used by this class.
            device: device that the model and every test batch are placed on.
            learning_rate: learning rate of the (restore-only) Adam optimizer.
            nsamples, sample_path: stored on the instance; not used by this class.
            checkpoints: path prefix; files are ``<checkpoints>_epochs<N>.pth``.
        """
        self.model = model
        self.model.to(device)
        self.device = device
        self.testloader = testloader
        self.log_path = log_path
        self.log_file = log_file
        self.learning_rate = learning_rate
        self.nsamples = nsamples
        self.sample_path = sample_path
        self.checkpoints = checkpoints
        self.start_epoch = 0
        self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
        self.epoch_losses = []
        self.loss = {}
        self.logger = get_logger(self.log_path, name=None, debug=False)

    def load_checkpoint(self, start_ep):
        """Restore model/optimizer state from ``<checkpoints>_epochs<start_ep>.pth``.

        If the file does not exist, ``start_epoch`` is reset to 0 and the model keeps
        its current weights. Any other failure (corrupt file, mismatched state dict,
        missing keys) propagates instead of being swallowed, so testing never runs on
        an untrained model by accident.
        """
        path = self.checkpoints + '_epochs{}.pth'.format(start_ep)
        print("Loading Chechpoint from ' {} '".format(path))
        try:
            # map_location lets a checkpoint saved on one device be loaded on another
            # (e.g. a GPU-trained checkpoint evaluated on CPU).
            checkpoint = torch.load(path, map_location=self.device)
        except FileNotFoundError:
            print("No Checkpoint Exists At '{}', Starting Fresh Training".format(path))
            self.start_epoch = 0
            return

        self.start_epoch = checkpoint['epoch']
        self.model.beta = checkpoint['beta']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epoch_losses = checkpoint['losses']
        print("Resuming Training From Epoch {}".format(self.start_epoch))

    def model_test_v2(self):
        """Run the model over ``self.testloader`` and log MAE / RMSE / MAPE / MSE.

        Metrics are logged for every index along dimension 1 of the concatenated
        targets (reported as the forecast horizon), then once over all of them.
        """
        self.model.eval()
        y_pred = []
        y_true = []
        with torch.no_grad():
            for data, target in self.testloader:
                # Add a singleton axis at dim 3 and move to the model's device
                # (not a hard-coded .cuda(), which breaks CPU runs and gpu_id != 0).
                # NOTE(review): assumes data/target are rank-3 (batch, time, feature) — confirm.
                data = data.unsqueeze(3).to(self.device)
                label = target.unsqueeze(3).to(self.device)

                # Only the last output of forward_test (the forecast) is scored.
                *_, output = self.forward_test(data)

                y_true.append(label)
                y_pred.append(output)

        if not y_true:
            # torch.cat on an empty list would raise an opaque RuntimeError.
            self.logger.info("Test loader yielded no batches; nothing to evaluate.")
            return

        y_true = torch.cat(y_true, dim=0)
        y_pred = torch.cat(y_pred, dim=0)

        print('testing!!!!!!!!')
        for t in range(y_true.shape[1]):
            mae, rmse, mape, mse, _ = All_Metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                                None, 0.)
            self.logger.info("Horizon {:02d}, MAE: {:.3f}, RMSE: {:.3f}, MAPE: {:.4f}%, MSE:{:.3f}".format(
                t + 1, mae, rmse, mape * 100, mse))
        mae, rmse, mape, mse, _ = All_Metrics(y_pred, y_true, None, 0.)
        self.logger.info("Average Horizon, MAE: {:.3f}, RMSE: {:.3f}, MAPE: {:.4f}% ,MSE:{:.3f}".format(
            mae, rmse, mape * 100, mse))

        print("Testing is complete!")

    def forward_test(self, data):
        """Run the model without gradients and return its eight outputs as a tuple.

        The order is: z posterior samples, posterior means, posterior log-variances,
        prior means, prior log-variances, x means, x log-sigmas, and the final output.
        The unpacking enforces that the model returns exactly eight values.
        """
        with torch.no_grad():
            (z_posterior_forward_list,
             z_mean_posterior_forward_list,
             z_logvar_posterior_forward_list,
             z_mean_prior_forward_list,
             z_logvar_prior_forward_list,
             x_mu_list,
             x_logsigma_list,
             h_out) = self.model(data)
        return (z_posterior_forward_list, z_mean_posterior_forward_list, z_logvar_posterior_forward_list,
                z_mean_prior_forward_list, z_logvar_prior_forward_list, x_mu_list, x_logsigma_list, h_out)

    def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logsigma):
        """Return the summed Gaussian log-likelihood of ``x`` under N(mu, exp(logsigma)^2).

        Computes -0.5 * sum(((x - mu) / sigma)^2 + 2*log(sigma) + log(2*pi)) over every
        element of the inputs. No slicing to the last timestamp happens here; callers
        presumably pass already-sliced tensors.
        """
        llh = -0.5 * torch.sum(torch.pow(((x.float() - recon_x_mu.float()) / torch.exp(recon_x_logsigma.float())),
                                         2) + 2 * recon_x_logsigma.float() + np.log(np.pi * 2))
        return llh


def main():
    """Parse CLI options, build the ETT test loader and model, restore a checkpoint and log test metrics."""
    import os
    # Allow multiple OpenMP runtimes in one process (presumably a workaround for the
    # Intel OpenMP "already initialized" abort seen with some torch/MKL installs).
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

    parser = argparse.ArgumentParser()
    # GPU
    parser.add_argument('--gpu_id', type=int, default=0)
    # Dataset options
    parser.add_argument('--dataset', default='ETTh1', type=str)
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--root_path', type=str, default='./data/ETT/',
                        help='root path of the data file')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--seq_len', type=int, default=48, help='input sequence length of Informer encoder')
    parser.add_argument('--pred_len', type=int, default=24, help='prediction sequence length')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--cols', type=str, nargs='+', help='certain cols from the data files as the input features')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
    parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)
    parser.add_argument('--batch_size', type=int, default=300)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--x_dim', type=int, default=7)
    parser.add_argument('--win_len', type=int, default=48,help='the same as seq_len')
    # Model options for VAGT
    parser.add_argument('--z_dim', type=int, default=25)
    parser.add_argument('--h_dim', type=int, default=50)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--layer_xz', type=int, default=2)
    parser.add_argument('--layer_h', type=int, default=3)
    parser.add_argument('--q_len', type=int, default=1, help='for conv1D padding in Transformer')
    parser.add_argument('--embd_h', type=int, default=128)
    parser.add_argument('--embd_s', type=int, default=256)
    parser.add_argument('--vocab_len', type=int, default=256)
    # Training options for VAGT
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--beta', type=float, default=0.0)
    parser.add_argument('--max_beta', type=float, default=1.0)
    parser.add_argument('--anneal_rate', type=float, default=0.05)

    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--checkpoints_interval', type=int, default=10)
    parser.add_argument('--checkpoints_path', type=str, default='model/ettm148')
    parser.add_argument('--checkpoints_file', type=str, default='')
    parser.add_argument('--log_path', type=str, default='log_tester/ettm148')
    parser.add_argument('--log_file', type=str, default='')

    args = parser.parse_args()

    # Set up GPU
    if torch.cuda.is_available() and args.gpu_id >= 0:
        device = torch.device('cuda:%d' % args.gpu_id)
    else:
        device = torch.device('cpu')

    # Output directories; exist_ok avoids the check-then-create race.
    os.makedirs(args.log_path, exist_ok=True)
    os.makedirs(args.checkpoints_path, exist_ok=True)

    # Derive checkpoint / log file names from the hyper-parameters when not given explicitly.
    # NOTE(review): the checkpoint name uses win_len while the model is built with seq_len;
    # keep the two in sync on the command line.
    if args.checkpoints_file == '':
        args.checkpoints_file = 'x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_' \
                                'win_len-{}_q_len-{}_vocab_len-{}'.format(args.x_dim, args.z_dim, args.h_dim,
                                                                          args.layer_xz, args.layer_h, args.embd_h,
                                                                          args.n_head, args.win_len, args.q_len,
                                                                          args.vocab_len)
    if args.log_file == '':
        args.log_file = 'x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_win_len-{}_' \
                        'q_len-{}_vocab_len-{}_epochs-{}_loss'.format(args.x_dim, args.z_dim, args.h_dim, args.layer_xz,
                                                       args.layer_h, args.embd_h, args.n_head, args.win_len,
                                                       args.q_len, args.vocab_len, args.epochs)

    # Known ETT datasets: CSV file name and target column. The 'M'/'S'/'MS' lists are not
    # read by this script. A dataset missing here silently keeps the --data_path default,
    # which is why ETTm2 is listed explicitly.
    data_parser = {
        'ETTh1': {'data': 'ETTh1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTh2': {'data': 'ETTh2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTm1': {'data': 'ETTm1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTm2': {'data': 'ETTm2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    }
    data_info = data_parser.get(args.dataset)
    if data_info is not None:
        args.data_path = data_info['data']
        args.target = data_info['T']
    args.detail_freq = args.freq
    # Keep only the last character of the frequency string (e.g. '15min' -> 'n').
    args.freq = args.freq[-1:]

    # NOTE(review): the train/val loaders are built but never used below; they are kept in
    # case _get_data has side effects the test split relies on. Confirm before removing.
    _, train_loader = _get_data(args, 'train')
    _, val_loader = _get_data(args, 'val')
    _, test_loader = _get_data(args, 'test')
    stackedvagt = StackedVAGT(layer_xz=args.layer_xz, layer_h=args.layer_h, n_head=args.n_head, x_dim=args.x_dim,
                              z_dim=args.z_dim, h_dim=args.h_dim, embd_h=args.embd_h, embd_s=args.embd_s,
                              beta=args.beta, q_len=args.q_len, vocab_len=args.vocab_len, win_len=args.seq_len,horizon=args.pred_len,
                              dropout=args.dropout, anneal_rate=args.anneal_rate, max_beta=args.max_beta,
                              device=device).to(device)
    # Debug dump of every parameter tensor together with its size.
    for name, parameters in stackedvagt.named_parameters():
        print(name, ':', parameters, parameters.size())

    # Build the tester and restore the checkpoint saved after `args.epochs` epochs.
    tester = Tester(stackedvagt, test_loader, log_path=args.log_path,
                    log_file=args.log_file, learning_rate=args.learning_rate, device=device,
                    checkpoints=os.path.join(args.checkpoints_path, args.checkpoints_file),
                    nsamples=None, sample_path=None)
    tester.load_checkpoint(args.epochs)

    # Inference on the test split.
    tester.model_test_v2()


if __name__ == "__main__":
    # Script entry: silence every warning so only the test log and metrics are
    # printed, then run the evaluation.
    from warnings import filterwarnings
    filterwarnings('ignore')
    main()